import numpy as np
import k_z_alg

# Passed as the second constructor argument to k_z_alg.K_z_algo by the tests
# below. NOTE(review): presumably the "z" parameter of the k-z algorithm;
# confirm its meaning against k_z_alg.
z: int = 3
def test_assignment():
    """Check that ``reassign_points`` labels each point with a center index.

    The expected labels [0, 1, 0] match assigning (0, 1) and (1, 1) to
    center 0 at (0, 0) and (100, 2) to center 1 at (50, 1), i.e. a
    nearest-center assignment for these inputs.
    """
    points = np.array([[0, 1], [100, 2], [1, 1]])
    centers = np.array([[0, 0], [50, 1]])
    alg = k_z_alg.K_z_algo(2, z, 1)
    alg.reassign_points(points, centers)
    # assert_array_equal reports the differing labels on failure, unlike a
    # bare ``assert np.array_equal(...)`` which only says "False".
    np.testing.assert_array_equal(alg.clusters, [0, 1, 0])

def test_optimize_centers():
    """Smoke-test ``optimize_centers`` on rescaled points and centers.

    Points and centers are both divided by the same factor, the largest
    Euclidean norm among the points, so the two arrays stay in one shared
    coordinate system.
    """
    points = np.array([[0, 1], [100, 2], [1, 1]])
    centers = np.array([[0, 0], [50, 1]])
    # Compute the scale once and apply it to both arrays. Dividing the centers
    # by the *points'* max norm (not their own) keeps them aligned with the
    # rescaled points.
    scale = np.max(np.linalg.norm(points, axis=1))
    normalized_points = points / scale
    normalized_centers = centers / scale
    alg = k_z_alg.K_z_algo(2, z, 1)
    alg.optimize_centers(
        normalized_points, normalized_centers, lr=5, gradient_steps=10000
    )
    # TODO(review): nothing is asserted about the optimized centers, so this
    # test only checks that optimize_centers runs without raising. Add an
    # expectation on its result once the intended outcome is pinned down.

def test_gradient():
    """Check ``calculate_gradient`` at several centers for two fixed points.

    For the points (1, 1) and (0, 1), the test expects:
      * a zero gradient at their midpoint (0.5, 1);
      * a first component equal to 1 at (1, 1);
      * equal second components at (0, 1) and (1, 1);
      * a negative first component at (0, 1).
    """
    alg = k_z_alg.K_z_algo(2, z, 1)
    points = np.array([[1, 1], [0, 1]])
    # Only checked to run without raising; no expected value is asserted here.
    alg.calculate_gradient(points, [0, 0])
    grad_left = alg.calculate_gradient(points, [0, 1])
    grad_right = alg.calculate_gradient(points, [1, 1])
    grad_mid = alg.calculate_gradient(points, [0.5, 1])
    # Gradients are floating-point results, so compare with a tolerance rather
    # than exact equality. An absolute tolerance is needed because the
    # expected value is zero and a relative tolerance alone would demand
    # an exact match.
    np.testing.assert_allclose(grad_mid, [0, 0], atol=1e-9)
    assert np.isclose(grad_right[0], 1)
    assert np.isclose(grad_left[1], grad_right[1])
    assert np.sign(grad_left[0]) == -1